import numpy as np


class Tensor:
    """
    A node in a computation graph supporting reverse-mode autodiff.

    Every arithmetic operation returns a new Tensor that records its input
    nodes together with local gradient functions, so that ``backward()`` can
    propagate gradients from an output node back to every input.
    """

    def __init__(self, data, depend=None, name='none') -> None:
        """
        data   -- the node's value (a plain number in this toy implementation)
        depend -- list of ``(input Tensor, grad_fn)`` pairs feeding this node;
                  ``grad_fn`` maps the upstream gradient to the share owed to
                  that input
        name   -- optional label for the node
        """
        self.data = data
        # Fresh list per instance: a mutable default argument ([]) would be
        # one shared list across every Tensor created without `depend`.
        self.depend = [] if depend is None else depend
        self.name = name
        self.grad = 0  # gradient accumulator, starts at zero

    @staticmethod
    def _wrap(value):
        """Promote a plain number to a constant (dependency-free) Tensor."""
        return value if isinstance(value, Tensor) else Tensor(value)

    def __mul__(self, other):
        """
        Multiplication y = self * other.

        d y / d self = other.data, d y / d other = self.data.
        `other` may be a Tensor or a plain number.
        """
        other = Tensor._wrap(other)

        def grad_fn_self(grad):
            return grad * other.data

        def grad_fn_other(grad):
            return grad * self.data
        return Tensor(
            self.data * other.data,  # forward value
            depend=[(self, grad_fn_self), (other, grad_fn_other)]
        )

    def __rmul__(self, other):
        """
        Reflected multiplication y = other * self.

        Invoked when the LEFT operand is not a Tensor (e.g. ``2 * t``), so
        `other` must be wrapped before multiplying.
        """
        return Tensor._wrap(other) * self

    def __add__(self, other):
        """
        Addition y = self + other: the gradient passes through unchanged
        to both operands. `other` may be a Tensor or a plain number.
        """
        other = Tensor._wrap(other)

        def grad_fn(grad):
            return grad
        return Tensor(
            self.data + other.data,  # forward value
            depend=[(self, grad_fn), (other, grad_fn)]
        )

    def __radd__(self, other):
        """
        Reflected addition y = other + self.

        Invoked when the LEFT operand is not a Tensor (e.g. ``1 + t``).
        """
        return Tensor._wrap(other) + self

    def __pow__(self, n):
        """
        Power y = self ** n for a constant exponent n.

        d y / d self = n * self.data ** (n - 1).
        """
        def grad_fn(grad):
            return grad * n * self.data ** (n - 1)
        return Tensor(
            self.data ** n,  # forward value
            depend=[(self, grad_fn)]
        )

    def backward(self, grad=None):
        """
        Reverse-mode gradient propagation.

        grad -- gradient flowing in from downstream; ``None`` marks this node
                as the root of the backward pass (seed gradient = 1).
        """
        if grad is None:
            self.grad = 1
            grad = 1
        else:
            self.grad += grad  # accumulate: a node may feed several consumers
        # Recurse into every input, scaled by the local gradient.
        for tensor, grad_fn in self.depend:
            tensor.backward(grad_fn(grad))

    def zero_grad(self):
        """
        Reset this node's gradient to 0, recursively clearing every
        upstream node as well.
        """
        self.grad = 0
        for tensor, _ in self.depend:
            tensor.zero_grad()


# Demo: build y = g + h where g = h = (x*x)*(x*x), i.e. y = 2*x**4.
x = Tensor(2)
x_sq = x * x        # x^2 = 4
g = x_sq * x_sq     # x^4 = 16
h = x_sq * x_sq     # a second, independent x^4 branch
y = g + h           # 2*x^4 = 32
y.backward()
print(f'y关于g的导数 {g.grad}')  # dy/dg = 1
print(f'y关于x的导数 {x.grad}')  # dy/dx = 8*x^3 = 64
y.zero_grad()
g.backward()
print(f'g关于x的导数 {x.grad}')  # dg/dx = 4*x^3 = 32
